/*
 * Copyright (c) 2017, Salesforce.com, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * * Neither the name of the copyright holder nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.salesforce.op

import com.salesforce.op.features.OPFeature
import com.salesforce.op.filters.RawFeatureFilter
import com.salesforce.op.readers.Reader
import com.salesforce.op.stages.OPStage
import com.salesforce.op.stages.impl.preparators.CorrelationType
import com.salesforce.op.stages.impl.selector.ModelSelectorBase
import com.salesforce.op.utils.reflection.ReflectionUtils
import com.salesforce.op.utils.stages.FitStagesUtil
import com.salesforce.op.utils.stages.FitStagesUtil.{CutDAG, FittedDAG, Layer, StagesDAG}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable.{ArrayBuffer, MutableList => MList}
import scala.util.{Failure, Success, Try}

/**
  * Created by allwefantasy on 18/9/2018.
  */
class WowOpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore with Serializable {
  // raw feature filter stage which can be used in place of a reader
  private[op] var rawFeatureFilter: Option[RawFeatureFilter[_]] = None

  /**
    * Set stage and reader parameters from an OpParams object for this run
    *
    * @param newParams new parameter values
    * @return this workflow
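    * @example A minimal usage sketch (`workflow` and `myParams` are hypothetical names for a
    *          WowOpWorkflow and an OpParams instance created elsewhere):
    *          {{{
    *            workflow.setParameters(myParams)
    *          }}}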
    */
  final def setParameters(newParams: OpParams): this.type = {
    parameters = newParams
    if (stages.nonEmpty) setStageParameters(stages)
    this
  }

  override def setStages(value: Array[OPStage]): this.type = {
    stages = value
    this
  }

  /**
    * Pretty-prints the resolved stage sequence and the parent stages of each result feature
    *
    * @return human-readable description of the dependency graphs
    */
  def prettyResultFeaturesDependencyGraphs: String = {
    val buffer = new ArrayBuffer[String]()

    buffer += "\nDependency graphs resolved into a stage sequence of:"
    buffer += getStages().map(s =>
      s" ${s.uid}[${s.getInputFeatures().map(_.name).mkString(",")}] --> ${s.getOutputFeatureName}"
    ).mkString("\n")
    buffer += ("*" * 80)
    buffer += "Result features:"
    resultFeatures.foreach(feature => buffer += s"${feature.name}:\n${feature.prettyParentStages}")
    buffer += ("*" * 80)
    buffer.mkString("\n")
  }

  /**
    * This is used to set the stages of the workflow.
    *
    * By setting the final features the stages used to
    * generate them can be traced back through the parent features and origin stages.
    * The input is a variable-length list of features, to support leaf feature generation (multiple endpoints in
    * feature generation).
    *
    * @param features Final features generated by the workflow
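    *
    * @example A minimal sketch (`prediction` and `label` are hypothetical features built elsewhere):
    *          {{{
    *            val workflow = new WowOpWorkflow().setResultFeatures(prediction, label)
    *          }}}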
    */
  def setResultFeatures(features: OPFeature*): this.type = {
    val featuresArr = features.toArray
    resultFeatures = featuresArr
    rawFeatures = featuresArr.flatMap(_.rawFeatures).distinct.sortBy(_.name)
    checkUnmatchedFeatures()
    setStagesDAG(features = featuresArr)
    validateStages()

    if (log.isInfoEnabled) {
      log.info(s"\nDependency graphs resolved into a stage sequence of:\n{}",
        getStages().map(s =>
          s" ${s.uid}[${s.getInputFeatures().map(_.name).mkString(",")}] --> ${s.getOutputFeatureName}"
        ).mkString("\n")
      )
      log.info("*" * 80)
      log.info("Result features:")
      resultFeatures.foreach(feature => log.info(s"${feature.name}:\n${feature.prettyParentStages}"))
      log.info("*" * 80)
    }
    this
  }

  /**
    * Will set the blacklisted features variable and, if the list is non-empty, remove the blacklisted
    * features from all stage inputs and update the DAG of result features accordingly
    *
    * @param features list of features to blacklist
    */
  private[op] def setBlacklist(features: Array[OPFeature]): Unit = {
    blacklistedFeatures = features
    if (blacklistedFeatures.nonEmpty) {
      val allBlacklisted: MList[OPFeature] = MList(getBlacklist(): _*)
      val allUpdated: MList[OPFeature] = MList.empty

      val initialResultFeatures = getResultFeatures()
      initialResultFeatures
        .foreach { f =>
          if (allBlacklisted.contains(f)) throw new IllegalArgumentException(
            s"Blacklist of features (${allBlacklisted.map(_.name).mkString(", ")})" +
              s" from RawFeatureFilter contained the result feature ${f.name}")
        }

      val initialStages = getStages() // ordered by DAG so don't need to recompute DAG
      // for each stage remove anything blacklisted from the inputs and update any changed input features
      initialStages.foreach { stg =>
        val inFeatures = stg.getInputFeatures()
        val blacklistRemoved = inFeatures.filterNot { f => allBlacklisted.exists(bl => bl.sameOrigin(f)) }
        val inputsChanged = blacklistRemoved.map { f => allUpdated.find(u => u.sameOrigin(f)).getOrElse(f) }
        val oldOutput = stg.getOutput()
        Try {
          stg.setInputFeatureArray(inputsChanged).setOutputFeatureName(oldOutput.name).getOutput()
        } match {
          case Success(out) => allUpdated += out
          case Failure(e) =>
            if (initialResultFeatures.contains(oldOutput)) throw new RuntimeException(
              s"Blacklist of features (${allBlacklisted.map(_.name).mkString(", ")}) \n" +
                s" created by RawFeatureFilter contained features critical to the creation of required result" +
                s" feature (${oldOutput.name}) though the path: \n ${oldOutput.prettyParentStages} \n", e)
            else allBlacklisted += oldOutput
        }
      }

      // Update the whole DAG with the blacklisted features expunged
      val updatedResultFeatures = initialResultFeatures
        .map { f => allUpdated.find(u => u.sameOrigin(f)).getOrElse(f) }
      setResultFeatures(updatedResultFeatures: _*)
    }
  }


  protected[op] def setBlacklistMapKeys(mapKeys: Map[String, Set[String]]): Unit = {
    blacklistedMapKeys = mapKeys
  }

  /**
    * Set parameters from stage params map unless param is set in code.
    * Note: Will NOT override parameter values that have been
    * set previously in code OR with a previous set of parameters.
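    *
    * @example Stage parameters are keyed by the stage's simple class name or its uid; a sketch of
    *          the expected shape (class and parameter names here are illustrative only):
    *          {{{
    *            Map("OpLogisticRegression" -> Map("regParam" -> 0.1))
    *          }}}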
    */
  private def setStageParameters(stages: Array[OPStage]): Unit = {
    val stageIds = stages.flatMap(s => Seq(s.getClass.getSimpleName, s.uid)).toSet
    val unmatchedStages = parameters.stageParams.keySet.filterNot(stageIds.contains)
    if (unmatchedStages.nonEmpty) log.error(s"Parameter settings with stage ids: $unmatchedStages had no matching" +
      s" stages in this workflow. Ids for the stages in this workflow are: ${stageIds.mkString(", ")}")
    for {
      (stageName, stageParams) <- parameters.stageParams
      stage <- stages.filter(s => s.getClass.getSimpleName == stageName || s.uid == stageName)
      (k, v) <- stageParams
    } {
      val setStage =
        Try {
          stage.set(stage.getParam(k), v)
        } orElse {
          Try {
            ReflectionUtils.reflectSetterMethod(stage, k, Seq(v))
          }
        }
      if (setStage.isFailure) log.error(
        s"Setting parameter $k with value $v for stage $stage with params ${stage.params.toList} failed with an error",
        setStage.failed.get
      )
      else log.info(s"Set parameter $k to value $v for stage $stage")
    }
  }

  /**
    * Uses input features to reconstruct the DAG of stages needed to generate them
    *
    * @param features final features passed into setResultFeatures
    */
  private def setStagesDAG(features: Array[OPFeature]): WowOpWorkflow.this.type = {
    // Unique stages layered by distance
    val uniqueStagesLayered = FitStagesUtil.computeDAG(features)

    if (log.isDebugEnabled) {
      val total = uniqueStagesLayered.map(_.length).sum
      val stages = for {
        layer <- uniqueStagesLayered
        (stage, distance) <- layer
      } yield s"$stage with distance $distance with output ${stage.getOutput().name}"

      log.debug("*" * 80)
      log.debug(s"Setting $total parent stages (sorted by distance desc):\n{}", stages.mkString("\n"))
      log.debug("*" * 80)
    }

    val uniqueStages: Array[OPStage] = uniqueStagesLayered.flatMap(_.map(_._1))

    setStageParameters(uniqueStages)
    setStages(uniqueStages)
  }

  /**
    * Used to generate a dataframe from the reader and the raw features list
    *
    * @return Dataframe with all the features generated + persisted
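    * @note If a RawFeatureFilter is set, its training reader takes precedence over a reader set
    *       directly on the workflow; if neither is set, an IllegalArgumentException is thrown.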
    */
  protected def generateRawData()(implicit spark: SparkSession): DataFrame = {
    (reader, rawFeatureFilter) match {
      case (None, None) => throw new IllegalArgumentException("Data reader must be set either directly on the" +
        " workflow or through the RawFeatureFilter")
      case (Some(r), None) =>
        checkReadersAndFeatures()
        r.generateDataFrame(rawFeatures, parameters).persist()
      case (rd, Some(rf)) =>
        rd match {
          case None => setReader(rf.trainingReader)
          case Some(r) => if (r != rf.trainingReader) log.warn("Workflow data reader and RawFeatureFilter training" +
            " reader do not match! The RawFeatureFilter training reader will be used to generate the data for training")
        }
        checkReadersAndFeatures()
        val filteredRawData = rf.generateFilteredRaw(rawFeatures, parameters)
        setBlacklist(filteredRawData.featuresToDrop)
        setBlacklistMapKeys(filteredRawData.mapKeysToDrop)
        filteredRawData.cleanedData
    }
  }

  /**
    * Transform function for testing chained transformations
    *
    * @param in DataFrame
    * @return transformed DataFrame
    */
  private[op] def transform(in: DataFrame, persistEveryKStages: Int = OpWorkflowModel.PersistEveryKStages)
                           (implicit sc: SparkSession): DataFrame = {
    val transformers = fitStages(in, stages, persistEveryKStages).map(_.asInstanceOf[Transformer])
    FitStagesUtil.applySparkTransformations(in, transformers, persistEveryKStages)
  }

  /**
    * Check if all the stages of the workflow are serializable
    *
    * @return Failure if not serializable
    */
  private[op] def checkSerializable(): Try[Unit] = Try {
    val failures = stages.map(s => s.uid -> s.checkSerializable).collect { case (stageId, Failure(e)) => stageId -> e }

    if (failures.nonEmpty) throw new IllegalArgumentException(
      s"All stages must be serializable. Failed stages: ${failures.map(_._1).mkString(",")}",
      failures.head._2
    )
  }

  /**
    * Check if all the stages of the workflow have uid argument in their constructors
    * (required for workflow save/load to work)
    *
    * @return Failure if there is at least one stage without a uid argument in constructor
    */
  private[op] def checkCtorUIDs(): Try[Unit] = checkCtorArgs(arg = "uid")

  /**
    * Check if all the stages of the workflow have a specified argument 'arg' in their constructors
    * (required for workflow save/load to work)
    *
    * @param arg ctor argument to check
    * @return Failure if there is at least one stage without a 'arg' argument in constructor
    */
  private[op] def checkCtorArgs(arg: String): Try[Unit] = Try {
    val failures =
      stages.map(s => s.uid -> ReflectionUtils.bestCtorWithArgs(s)._2.map(_._1))
        .collect { case (stageId, args) if !args.contains(arg) => stageId }

    if (failures.nonEmpty) throw new IllegalArgumentException(
      s"All stages must be have $arg as their ctor argument. Failed stages: ${failures.mkString(",")}"
    )
  }

  /**
    * Check if all the stages of the workflow have distinct uids
    * (required for workflow save/load to work)
    *
    * @return Failure if any two stages share the same uid
    */
  private[op] def checkDistinctUIDs(): Try[Unit] = Try {
    if (stages.map(_.uid).distinct.length != stages.length) throw new IllegalArgumentException(
      "All stages must be distinct instances with distinct uids for saving"
    )
  }

  /**
    * Validate all the workflow stages
    *
    * @throws IllegalArgumentException if any of the stage checks fail
    */
  private[op] def validateStages(): Unit = {
    val res = for {
      _ <- checkCtorUIDs()
      _ <- checkSerializable()
      _ <- checkDistinctUIDs()
    } yield ()
    if (res.isFailure) throw res.failed.get
  }

  /**
    * Fit all of the estimators in the pipeline and return a pipeline model of only transformers. Uses data loaded
    * as specified by the data reader to generate the initial data set.
    *
    * @param persistEveryKStages persist data in transforms every k stages for performance improvement
    * @return a fitted pipeline model
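    *
    * @example A minimal sketch (assumes an implicit SparkSession in scope and a workflow whose
    *          reader and result features have already been set):
    *          {{{
    *            implicit val spark: SparkSession = SparkSession.builder().getOrCreate()
    *            val model: OpWorkflowModel = workflow.train()
    *          }}}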
    */
  def train(persistEveryKStages: Int = OpWorkflowModel.PersistEveryKStages)
           (implicit spark: SparkSession): OpWorkflowModel = {

    val (fittedStages, newResultFeatures) =
      if (stages.exists(_.isInstanceOf[Estimator[_]])) {
        val rawData = generateRawData()

        // Update features with fitted stages
        val fittedStgs = fitStages(data = rawData, stagesToFit = stages, persistEveryKStages)
        val newResultFtrs = resultFeatures.map(_.copyWithNewStages(fittedStgs))
        fittedStgs -> newResultFtrs
      } else {
        stages -> resultFeatures
      }

    val model =
      new OpWorkflowModel(uid, getParameters())
        .setStages(fittedStages)
        .setFeatures(newResultFeatures)
        .setParameters(getParameters())
        .setBlacklist(getBlacklist())
        .setBlacklistMapKeys(getBlacklistMapKeys())

    reader.map(model.setReader).getOrElse(model)
  }


  /**
    * Fit all of the estimators in the pipeline, as in [[train]], but fitting via [[fitFeatureStages]]
    * and returning a [[WowOpWorkflowModel]]
    *
    * @param persistEveryKStages persist data in transforms every k stages for performance improvement
    * @return a fitted workflow model
    */
  def trainFeatureModel(persistEveryKStages: Int = OpWorkflowModel.PersistEveryKStages)
                       (implicit spark: SparkSession): WowOpWorkflowModel = {

    val (fittedStages, newResultFeatures) =
      if (stages.exists(_.isInstanceOf[Estimator[_]])) {
        val rawData = generateRawData()

        // Update features with fitted stages
        val fittedStgs = fitFeatureStages(data = rawData, stagesToFit = stages, persistEveryKStages)
        val newResultFtrs = resultFeatures.map(_.copyWithNewStages(fittedStgs))
        fittedStgs -> newResultFtrs
      } else {
        stages -> resultFeatures
      }

    val model =
      new WowOpWorkflowModel(uid, getParameters())
        .setStages(fittedStages)
        .setFeatures(newResultFeatures)
        .setParameters(getParameters())
        .setBlacklist(getBlacklist())
        .setBlacklistMapKeys(getBlacklistMapKeys())

    reader.map(model.setReader).getOrElse(model)
  }


  /**
    * Fit the estimators to return a sequence of only transformers
    * Modified version of Spark 2.x Pipeline
    *
    * @param data                dataframe to fit on
    * @param stagesToFit         stages that need to be converted to transformers
    * @param persistEveryKStages persist data in transforms every k stages for performance improvement
    * @return fitted transformers
    */
  protected def fitFeatureStages(data: DataFrame, stagesToFit: Array[OPStage], persistEveryKStages: Int)
                                (implicit spark: SparkSession): Array[OPStage] = {

    // TODO may want to make workflow take an optional reserve fraction
    val splitters = stagesToFit.collect { case s: ModelSelectorBase[_, _] => s.splitter }.flatten
    val splitter = splitters.reduceOption { (a, b) =>
      if (a.getReserveTestFraction > b.getReserveTestFraction) a else b
    }
    val (train, test) = splitter.map(_.split(data)).getOrElse((data, spark.emptyDataFrame))
    val hasTest = !test.isEmpty

    val dag = FitStagesUtil.computeDAG(resultFeatures)
      .map(_.filter(s => stagesToFit.contains(s._1)))
      .filter(_.nonEmpty)

    // Search for the last estimator
    val indexOfLastEstimator: Option[Int] =
      dag.collect { case seq if seq.exists(_._1.isInstanceOf[Estimator[_]]) => seq.head._2 }.lastOption

    // doing regular workflow fit without workflow level CV
    if (!isWorkflowCV) {
      FitStagesUtil.fitAndTransformDAG(
        dag = dag,
        train = train,
        test = test,
        hasTest = hasTest,
        indexOfLastEstimator = indexOfLastEstimator,
        persistEveryKStages = persistEveryKStages
      ).transformers
    } else {
      // doing workflow level CV/TS
      // Extract the model selector and split the DAG into the layers before, during, and after it
      val CutDAG(modelSelectorOpt, before, during, after) = FitStagesUtil.cutDAG(dag)

      log.info("Applying initial DAG before CV/TS. Stages: {}", before.flatMap(_.map(_._1.stageName)).mkString(", "))
      val FittedDAG(beforeTrain, beforeTest, beforeTransformers) = FitStagesUtil.fitAndTransformDAG(
        dag = before,
        train = train,
        test = test,
        hasTest = hasTest,
        indexOfLastEstimator = indexOfLastEstimator,
        persistEveryKStages = persistEveryKStages
      )

      // Break up the Catalyst plan (it chokes on long lineages) by converting to an RDD, persisting it,
      // and then converting back to a dataframe
      val (trainRDD, testRDD) = (beforeTrain.rdd.persist(), beforeTest.rdd.persist())
      val (trainFixed, testFixed) = (
        spark.createDataFrame(trainRDD, beforeTrain.schema),
        spark.createDataFrame(testRDD, beforeTest.schema)
      )

      modelSelectorOpt match {
        case None => beforeTransformers
        case Some((modelSelector, distance)) =>
          // estimate best model
          log.info("Estimate best Model with CV/TS. Stages included in CV are: {}, {}",
            during.flatMap(_.map(_._1.stageName)).mkString(", "), modelSelector.uid: Any
          )
          modelSelector.findBestEstimator(trainFixed, during, persistEveryKStages)
          val remainingDAG: StagesDAG = (during :+ (Array(modelSelector -> distance): Layer)) ++ after

          log.info("Applying DAG after CV/TS. Stages: {}", remainingDAG.flatMap(_.map(_._1.stageName)).mkString(", "))
          val fitted = FitStagesUtil.fitAndTransformDAG(
            dag = remainingDAG,
            train = trainFixed,
            test = testFixed,
            hasTest = hasTest,
            indexOfLastEstimator = indexOfLastEstimator,
            persistEveryKStages = persistEveryKStages,
            fittedTransformers = beforeTransformers
          ).transformers
          trainRDD.unpersist()
          testRDD.unpersist()
          fitted
      }
    }
  }

  /**
    * Fit the estimators to return a sequence of only transformers
    * Modified version of Spark 2.x Pipeline
    *
    * @param data                dataframe to fit on
    * @param stagesToFit         stages that need to be converted to transformers
    * @param persistEveryKStages persist data in transforms every k stages for performance improvement
    * @return fitted transformers
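    *
    * @note The DAG computed by FitStagesUtil.computeDAG is a sequence of layers, each layer holding
    *       (stage, distance from the result features) pairs; stages are fitted layer by layer, with
    *       an optional workflow-level CV/TS cut around the model selector.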
    */
  protected def fitStages(data: DataFrame, stagesToFit: Array[OPStage], persistEveryKStages: Int)
                         (implicit spark: SparkSession): Array[OPStage] = {

    // TODO may want to make workflow take an optional reserve fraction
    val splitters = stagesToFit.collect { case s: ModelSelectorBase[_, _] => s.splitter }.flatten
    val splitter = splitters.reduceOption { (a, b) =>
      if (a.getReserveTestFraction > b.getReserveTestFraction) a else b
    }
    val (train, test) = splitter.map(_.split(data)).getOrElse((data, spark.emptyDataFrame))
    val hasTest = !test.isEmpty

    val dag = FitStagesUtil.computeDAG(resultFeatures)
      .map(_.filter(s => stagesToFit.contains(s._1)))
      .filter(_.nonEmpty)

    // Search for the last estimator
    val indexOfLastEstimator: Option[Int] =
      dag.collect { case seq if seq.exists(_._1.isInstanceOf[Estimator[_]]) => seq.head._2 }.lastOption

    // doing regular workflow fit without workflow level CV
    if (!isWorkflowCV) {
      FitStagesUtil.fitAndTransformDAG(
        dag = dag,
        train = train,
        test = test,
        hasTest = hasTest,
        indexOfLastEstimator = indexOfLastEstimator,
        persistEveryKStages = persistEveryKStages
      ).transformers
    } else {
      // doing workflow level CV/TS
      // Extract the model selector and split the DAG into the layers before, during, and after it
      val CutDAG(modelSelectorOpt, before, during, after) = FitStagesUtil.cutDAG(dag)

      log.info("Applying initial DAG before CV/TS. Stages: {}", before.flatMap(_.map(_._1.stageName)).mkString(", "))
      val FittedDAG(beforeTrain, beforeTest, beforeTransformers) = FitStagesUtil.fitAndTransformDAG(
        dag = before,
        train = train,
        test = test,
        hasTest = hasTest,
        indexOfLastEstimator = indexOfLastEstimator,
        persistEveryKStages = persistEveryKStages
      )

      // Break up the Catalyst plan (it chokes on long lineages) by converting to an RDD, persisting it,
      // and then converting back to a dataframe
      val (trainRDD, testRDD) = (beforeTrain.rdd.persist(), beforeTest.rdd.persist())
      val (trainFixed, testFixed) = (
        spark.createDataFrame(trainRDD, beforeTrain.schema),
        spark.createDataFrame(testRDD, beforeTest.schema)
      )

      modelSelectorOpt match {
        case None => beforeTransformers
        case Some((modelSelector, distance)) =>
          // estimate best model
          log.info("Estimate best Model with CV/TS. Stages included in CV are: {}, {}",
            during.flatMap(_.map(_._1.stageName)).mkString(", "), modelSelector.uid: Any
          )
          modelSelector.findBestEstimator(trainFixed, during, persistEveryKStages)
          val remainingDAG: StagesDAG = (during :+ (Array(modelSelector -> distance): Layer)) ++ after

          log.info("Applying DAG after CV/TS. Stages: {}", remainingDAG.flatMap(_.map(_._1.stageName)).mkString(", "))
          val fitted = FitStagesUtil.fitAndTransformDAG(
            dag = remainingDAG,
            train = trainFixed,
            test = testFixed,
            hasTest = hasTest,
            indexOfLastEstimator = indexOfLastEstimator,
            persistEveryKStages = persistEveryKStages,
            fittedTransformers = beforeTransformers
          ).transformers
          trainRDD.unpersist()
          testRDD.unpersist()
          fitted
      }
    }
  }

  /**
    * Replaces any estimators in this workflow with their corresponding fit models from the OpWorkflowModel
    * passed in. Note that the Stages UIDs must EXACTLY correspond in order to be replaced so the same features
    * and stages must be used in both the fitted OpWorkflowModel and this OpWorkflow.
    * Any estimators that are not part of the OpWorkflowModel passed in will be trained when .train()
    * is called on this OpWorkflow.
    *
    * @param model model containing fitted stages to be used in this workflow
    * @return an OpWorkflow containing all of the stages from this model plus any new stages
    *         needed to generate the features not included in the fitted model
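    *
    * @example A minimal sketch (`fittedModel` is a hypothetical, previously trained OpWorkflowModel):
    *          {{{
    *            val warmStarted = workflow.withModelStages(fittedModel)
    *          }}}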
    */
  def withModelStages(model: OpWorkflowModel): this.type = {
    val newResultFeatures = (resultFeatures ++ model.getResultFeatures()).map(_.copyWithNewStages(model.stages))
    setResultFeatures(newResultFeatures: _*)
  }

  /**
    * Load a previously trained workflow model from path
    *
    * @param path to the trained workflow model
    * @return workflow model
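    *
    * @example A minimal sketch (the path shown is illustrative only):
    *          {{{
    *            val model: WowOpWorkflowModel = workflow.loadModel("/tmp/op-model")
    *          }}}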
    */
  def loadModel(path: String): WowOpWorkflowModel = new WowOpWorkflowModelReader(this).load(path)

  /**
    * Returns a dataframe containing all the columns generated up to and including the input feature
    *
    * @param feature             input feature to compute up to
    * @param persistEveryKStages persist data in transforms every k stages for performance improvement
    * @return Dataframe containing columns corresponding to all of the features generated up to the feature given
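    *
    * @example A minimal sketch (`normalizedAge` is a hypothetical intermediate feature of this workflow):
    *          {{{
    *            val partialData: DataFrame = workflow.computeDataUpTo(normalizedAge)
    *          }}}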
    */
  def computeDataUpTo(feature: OPFeature, persistEveryKStages: Int = OpWorkflowModel.PersistEveryKStages)
                     (implicit spark: SparkSession): DataFrame = {
    if (findOriginStageId(feature).isEmpty) {
      log.warn("Could not find origin stage for feature in workflow!! Defaulting to generate raw features.")
      generateRawData()
    } else {
      val rawData = generateRawData()
      val stagesToFit = FitStagesUtil.computeDAG(Array(feature)).flatMap(_.map(_._1))
      val fittedStages = fitStages(rawData, stagesToFit, persistEveryKStages)
      val updatedFeature = feature.copyWithNewStages(fittedStages)
      val dag = FitStagesUtil.computeDAG(Array(updatedFeature))
      applyTransformationsDAG(rawData, dag, persistEveryKStages)
    }
  }

  /**
    * Add a raw features filter to the workflow to look at fill rates and distributions of raw features and exclude
    * features that do not meet specifications from the modeling DAG
    *
    * @param trainingReader    training reader to use in the filter; if not supplied, will fall back to the reader
    *                          specified for the workflow (note that this reader takes precedence over readers
    *                          directly input to the workflow if both are supplied)
    * @param scoringReader     scoring reader to use in the filter; if not supplied, will do only the checks possible
    *                          with training data available
    * @param bins              number of bins to use in estimating feature distributions
    * @param minFillRate       minimum non-null fraction of instances that a feature should contain
    * @param maxFillDifference maximum absolute difference in fill rate between scoring and training data for a feature
    * @param maxFillRatioDiff  maximum difference in fill ratio (symmetric) between scoring and training data for a feature
    * @param maxJSDivergence   maximum Jensen-Shannon divergence between the training and scoring distributions
    *                          for a feature
    * @param maxCorrelation    maximum correlation allowed between the null indicator of a raw feature and the label
    * @param correlationType   type of correlation metric to use in the above check
    * @param protectedFeatures list of features that should never be removed (features that are used to create them
    *                          will also be protected)
    * @tparam T Type of the data read in
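    *
    * @example A minimal sketch (`trainReader` and `scoreReader` are hypothetical readers of type Reader[T]):
    *          {{{
    *            workflow.withRawFeatureFilter(
    *              trainingReader = Option(trainReader),
    *              scoringReader = Option(scoreReader)
    *            )
    *          }}}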
    */
  @Experimental
  def withRawFeatureFilter[T](
                               trainingReader: Option[Reader[T]],
                               scoringReader: Option[Reader[T]],
                               bins: Int = 100,
                               minFillRate: Double = 0.001,
                               maxFillDifference: Double = 0.90,
                               maxFillRatioDiff: Double = 20.0,
                               maxJSDivergence: Double = 0.90,
                               maxCorrelation: Double = 0.95,
                               correlationType: CorrelationType = CorrelationType.Pearson,
                               protectedFeatures: Array[OPFeature] = Array.empty
                             ): this.type = {
    val training = trainingReader.orElse(reader).map(_.asInstanceOf[Reader[T]])
    require(training.nonEmpty, "Reader for training data must be provided either in withRawFeatureFilter or directly" +
      " as the reader for the workflow")
    val protectedRawFeatures = protectedFeatures.flatMap(_.rawFeatures).map(_.name).toSet
    rawFeatureFilter = Option {
      new RawFeatureFilter(
        trainingReader = training.get,
        scoreReader = scoringReader,
        bins = bins,
        minFill = minFillRate,
        maxFillDifference = maxFillDifference,
        maxFillRatioDiff = maxFillRatioDiff,
        maxJSDivergence = maxJSDivergence,
        maxCorrelation = maxCorrelation,
        correlationType = correlationType,
        protectedFeatures = protectedRawFeatures)
    }
    this
  }
}
